{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Results dataset from prompts after Inversion\n",
    "import pandas as pd\n",
    "compare_promts = pd.read_parquet('XXX')\n",
    "\n",
    "# Fail fast if a column read by the metric cells below is absent, rather than\n",
    "# hitting a KeyError part-way through a long image-generation loop.\n",
    "required_columns = {'rl_generation', 'reference_prompt', 'img_path'}\n",
    "missing_columns = required_columns - set(compare_promts.columns)\n",
    "assert not missing_columns, f'missing columns: {sorted(missing_columns)}'\n",
    "\n",
    "compare_promts"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## BERTScore: text similarity between generated and reference prompts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bert_score import score\n",
    "\n",
    "_bert_scorer = None  # built on first use so the model is loaded once per kernel\n",
    "\n",
    "\n",
    "def bert_sim_eval(candidate_sentence, target_sentence):\n",
    "    \"\"\"Compute BERTScore precision, recall and F1 for one sentence pair.\n",
    "\n",
    "    Args:\n",
    "        candidate_sentence: Generated text to evaluate.\n",
    "        target_sentence: Reference text the candidate is compared against.\n",
    "\n",
    "    Returns:\n",
    "        dict with float values under the keys \"P\", \"R\" and \"F1\".\n",
    "    \"\"\"\n",
    "    global _bert_scorer\n",
    "    if _bert_scorer is None:\n",
    "        # bert_score.score() builds a fresh model on every call; a BERTScorer\n",
    "        # keeps it in memory, so scoring row by row does not reload the weights.\n",
    "        from bert_score import BERTScorer\n",
    "        _bert_scorer = BERTScorer(lang='en')\n",
    "\n",
    "    # The scorer expects lists of candidates and references.\n",
    "    # verbose is off: per-pair progress output would flood the cell output.\n",
    "    P, R, F1 = _bert_scorer.score([candidate_sentence], [target_sentence], verbose=False)\n",
    "    return {\"P\": P.item(), \"R\": R.item(), \"F1\": F1.item()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "# Per-row BERTScore of each generated prompt against its reference prompt\n",
    "P_ours, R_ours, F1_ours = [], [], []\n",
    "prompt_pairs = zip(compare_promts['rl_generation'], compare_promts['reference_prompt'])\n",
    "for generated, reference in tqdm(prompt_pairs, total=len(compare_promts)):\n",
    "    pair_scores = bert_sim_eval(generated, reference)\n",
    "    P_ours.append(pair_scores['P'])\n",
    "    R_ours.append(pair_scores['R'])\n",
    "    F1_ours.append(pair_scores['F1'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset means, printed in the order precision, recall, F1\n",
    "for metric_values in (P_ours, R_ours, F1_ours):\n",
    "    print(sum(metric_values) / len(metric_values))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CLIP: Calculate similarity between target image and images generated from prompts using two different methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import open_clip\n",
    "import torch\n",
    "from PIL import Image\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(\n",
    "    \"ViT-L-14\", pretrained=\"openai\", device=device\n",
    ")\n",
    "# The model is only used for inference in this notebook, so make sure it is\n",
    "# not left in training mode.\n",
    "clip_model.eval()\n",
    "\n",
    "def measure_clip_imgtxt_similarity(image_path, txt_1):\n",
    "    \"\"\"Cosine similarity between the CLIP embeddings of an image file and a text.\n",
    "\n",
    "    Args:\n",
    "        image_path: Path of the image; opened with PIL and run through clip_preprocess.\n",
    "        txt_1: Text to tokenize and embed.\n",
    "\n",
    "    Returns:\n",
    "        The cosine similarity converted to a Python list.\n",
    "    \"\"\"\n",
    "    token_batch = open_clip.tokenize([txt_1]).to(device)\n",
    "    image_batch = clip_preprocess(Image.open(image_path)).unsqueeze(0).to(device)\n",
    "\n",
    "    # Embeddings are computed without tracking gradients\n",
    "    with torch.no_grad():\n",
    "        image_emb = clip_model.encode_image(image_batch)\n",
    "        text_emb = clip_model.encode_text(token_batch)\n",
    "\n",
    "    similarity = torch.nn.functional.cosine_similarity(image_emb, text_emb, dim=1)\n",
    "    return similarity.cpu().numpy().tolist()\n",
    "\n",
    "    \n",
    "\n",
    "def measure_clip_imgs_similarity(orig_images_t, pred_imgs_t, clip_model):\n",
    "    \"\"\"Mean pairwise cosine similarity between two batches of image embeddings.\n",
    "\n",
    "    Args:\n",
    "        orig_images_t: Batch of preprocessed reference images.\n",
    "        pred_imgs_t: Batch of preprocessed generated images.\n",
    "        clip_model: Model exposing ``encode_image``.\n",
    "\n",
    "    Returns:\n",
    "        float: average over every (reference, generated) pair.\n",
    "    \"\"\"\n",
    "    def _unit_features(batch):\n",
    "        # Scale each row to unit length so a dot product equals cosine similarity\n",
    "        features = clip_model.encode_image(batch)\n",
    "        return features / features.norm(dim=1, keepdim=True)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        pairwise = _unit_features(orig_images_t) @ _unit_features(pred_imgs_t).t()\n",
    "        return pairwise.mean().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffusers import StableDiffusionPipeline\n",
    "from diffusers import PNDMScheduler\n",
    "from PIL import Image\n",
    "import torch\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "# Half precision is only requested on GPU; on the CPU fallback the pipeline\n",
    "# stays in float32, since float16 inference is poorly supported on CPU.\n",
    "pipe_dtype = torch.float16 if device == \"cuda\" else torch.float32\n",
    "\n",
    "# NOTE(review): confirm this repo id still resolves on the Hugging Face Hub.\n",
    "model_id = \"runwayml/stable-diffusion-v1-5\"\n",
    "scheduler = PNDMScheduler.from_pretrained(model_id, subfolder=\"scheduler\")\n",
    "\n",
    "# Load the pipeline from the same checkpoint as the scheduler (model_id)\n",
    "# instead of repeating the literal.\n",
    "pipe = StableDiffusionPipeline.from_pretrained(\n",
    "    model_id,\n",
    "    scheduler=scheduler,\n",
    "    torch_dtype=pipe_dtype,\n",
    ").to(device)\n",
    "\n",
    "# Height and width (in pixels) requested from the pipeline in the cells below\n",
    "image_length = 512"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# Fixed seed so that re-running this cell regenerates the same images\n",
    "GENERATION_SEED = 0\n",
    "\n",
    "eval_sim_ours = []\n",
    "\n",
    "for index, item in tqdm(compare_promts.iterrows(), total=len(compare_promts)):\n",
    "    orig_image = Image.open(item['img_path']).convert('RGB')\n",
    "    # Re-seed for every prompt: the sampled image then does not depend on how\n",
    "    # many rows were processed before it.\n",
    "    generator = torch.Generator(device=device).manual_seed(GENERATION_SEED)\n",
    "    with torch.no_grad():\n",
    "        pred_imgs = pipe(\n",
    "            item['rl_generation'],\n",
    "            num_images_per_prompt=1,\n",
    "            guidance_scale=7.5,\n",
    "            num_inference_steps=50,\n",
    "            height=image_length,\n",
    "            width=image_length,\n",
    "            generator=generator,\n",
    "        ).images\n",
    "        # One CLIP-preprocessed reference image against the batch of generated images\n",
    "        orig_images_t = clip_preprocess(orig_image).unsqueeze(0).to(device)\n",
    "        pred_imgs_t = torch.stack([clip_preprocess(img) for img in pred_imgs]).to(device)\n",
    "        eval_sim_ours.append(measure_clip_imgs_similarity(orig_images_t, pred_imgs_t, clip_model))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean CLIP image-to-image similarity over all rows\n",
    "sum(eval_sim_ours) / len(eval_sim_ours)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LPIPS: Perceptual Distance Between Target and Generated Images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import lpips\n",
    "import itertools\n",
    "from PIL import Image\n",
    "from torchvision import transforms\n",
    "\n",
    "# LPIPS network with the 'alex' backbone, moved to the device selected earlier\n",
    "lpips_model = lpips.LPIPS(net='alex').to(device)\n",
    "\n",
    "# Image preprocessing: resize to 256x256, then convert to a tensor\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize((256, 256)),\n",
    "    transforms.ToTensor(),\n",
    "])\n",
    "\n",
    "def load_image(image_path):\n",
    "    \"\"\"Open an image file as RGB and return it as a transformed tensor with a leading batch axis.\"\"\"\n",
    "    rgb_image = Image.open(image_path).convert('RGB')\n",
    "    image_tensor = transform(rgb_image)\n",
    "    return image_tensor.unsqueeze(0)\n",
    "\n",
    "def measure_lpips_imgs_similarity(img1, img2, normalize=True):\n",
    "    \"\"\"Return the LPIPS distance between two image tensors as a float.\n",
    "\n",
    "    Args:\n",
    "        img1: First image tensor, as accepted by ``lpips_model``.\n",
    "        img2: Second image tensor, as accepted by ``lpips_model``.\n",
    "        normalize: Forwarded to LPIPS. True (default) declares the inputs to be\n",
    "            in [0, 1], which is what the ToTensor-based ``transform`` in this\n",
    "            notebook produces, so LPIPS rescales them to the [-1, 1] range it\n",
    "            expects. Pass False for tensors that are already in [-1, 1].\n",
    "\n",
    "    Returns:\n",
    "        float: the distance. ``.item()`` requires the result to hold exactly\n",
    "        one value, i.e. a single image pair.\n",
    "    \"\"\"\n",
    "    # Pure evaluation: no gradients are needed\n",
    "    with torch.no_grad():\n",
    "        distance = lpips_model(img1, img2, normalize=normalize)\n",
    "    return distance.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed seed so that re-running this cell regenerates the same images\n",
    "GENERATION_SEED = 0\n",
    "\n",
    "lpips_sim_ours = []\n",
    "\n",
    "for index, item in tqdm(compare_promts.iterrows(), total=len(compare_promts)):\n",
    "    orig_image_t = load_image(item['img_path']).to(device)\n",
    "    # Re-seed for every prompt: the sampled image then does not depend on how\n",
    "    # many rows were processed before it.\n",
    "    generator = torch.Generator(device=device).manual_seed(GENERATION_SEED)\n",
    "    with torch.no_grad():\n",
    "        pred_img = pipe(\n",
    "            item['rl_generation'],\n",
    "            num_images_per_prompt=1,\n",
    "            guidance_scale=7.5,\n",
    "            num_inference_steps=50,\n",
    "            height=image_length,\n",
    "            width=image_length,\n",
    "            generator=generator,\n",
    "        ).images[0]\n",
    "        pred_img_t = transform(pred_img).unsqueeze(0).to(device)\n",
    "\n",
    "        # Both tensors come out of ToTensor and are in [0, 1]; normalize=True\n",
    "        # makes LPIPS rescale them to the [-1, 1] range it expects.\n",
    "        distance = lpips_model(orig_image_t, pred_img_t, normalize=True)\n",
    "        lpips_sim_ours.append(distance.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean LPIPS value over all rows\n",
    "print(sum(lpips_sim_ours) / len(lpips_sim_ours))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ldm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
